close all
clear

% ---- User settings -------------------------------------------------------
folderName = '2023_02_14_2_2D2_qw_n2_6Er_scan_phase';
prefix = 'scan';
paramVar ='stat_phase_pi';      % scanned parameter (x-axis of the final plot)

plot_figure = 1;
export_data = 0;

% Only runs whose paramVar2 value equals evolution_dur are analysed.
paramVar2 = 'evolution_dur';
evolution_dur = 10.875; 

% Columns of atomMatrix that are analysed (each column is treated separately).
col1 = 118;
col2 = 128; 

row1 = 89; %difference must be even
row2 = 112; %starting site must be in the center
centersites = [(row2-row1+1)/2 (row2-row1+1)/2+1];
row_start = row1 + 9; % ROI is 5 sites to account for doublon detection
row_stop = row1 + 13;
rlen = row2-row1+1;

% ---- Batch parameters ----------------------------------------------------
[param_names, param_table] = get_batch_params(folderName); %param_table has extra zero at the end for some reason
paramInd = find(strcmp(param_names, paramVar));
param2Ind = find(strcmp(param_names, paramVar2));
doublonInd = find(strcmp(param_names, 'doublon_ramp_dur'));

% Fail early with a clear message instead of silently indexing with an
% empty column index further down.
if isempty(paramInd) || isempty(param2Ind) || isempty(doublonInd)
    error('scanPhase:missingParam', ...
        'Parameter "%s", "%s" or "doublon_ramp_dur" not found in the batch parameters of "%s".', ...
        paramVar, paramVar2, folderName);
end

all_phase_values = param_table(:, paramInd);
all_phase_length = length(all_phase_values);   % number of rows before filtering
atomMatrix = zeros(248,248);

% Restrict to the lines that have the good evolution dur:
alltvalues = param_table(:, param2Ind);
good_index = alltvalues == evolution_dur;
all_phase_values = all_phase_values(good_index);
all_phase_len = length(all_phase_values);
if all_phase_len == 0
    error('scanPhase:noRuns', ...
        'No run with %s == %g found in "%s".', paramVar2, evolution_dur, folderName);
end

% Unique phases are taken from the *filtered* list, so every entry has at
% least one retained run. Taking them from the unfiltered table would also
% pick up phases of discarded rows and yield all-NaN rows when averaging
% per phase.
unique_phase_vals = unique(all_phase_values);
phase_len = numel(unique_phase_vals);

% NOTE(review): this list is built from the unfiltered table, so it may
% contain values that only occur in discarded rows - confirm.
doublon_ramp_dur_all = param_table(:, doublonInd);
doublon_ramp_dur_list = unique(doublon_ramp_dur_all);


%% Extract data at short times
% For every run and every shot of that run, each analysed column that holds
% exactly two atoms between row1 and row2 is classified by its occupation
% pattern inside the 5-site ROI (row_start:row_stop). Counts are accumulated
% per run index in state_hist (doublon_ramp_dur == 0) or state_hist_doublon
% (doublon_ramp_dur > 0); the last column counts unmatched patterns.

% Number of runs per doublon_ramp_dur value; runs are assumed to be stored
% in consecutive blocks of this length, one block per value.
phase_len_aux = all_phase_len / numel(doublon_ramp_dur_list);
if phase_len_aux ~= floor(phase_len_aux)
    error('scanPhase:unevenBlocks', ...
        'Number of selected runs (%d) is not a multiple of the number of doublon_ramp_dur values (%d).', ...
        all_phase_len, numel(doublon_ramp_dur_list));
end

% Two-particle states on four sites, and the corresponding 5-site ROI
% patterns for the two measurement types. NaN rows can never match an
% occupation pattern, i.e. that state is not identified in that data set.
basis = [1 1 0 0; 1 0 1 0; 1 0 0 1; 0 1 1 0; 0 1 0 1; 0 0 1 1; 2 0 0 0; 0 2 0 0; 0 0 2 0; 0 0 0 2];
basis_doublon = [NaN NaN NaN NaN NaN; 0 1 0 1 0; 0 1 0 0 1; NaN NaN NaN NaN NaN; 0 0 1 0 1; 0 0 0 0 0; 1 1 0 0 0; 0 1 1 0 0; 0 0 1 1 0; 0 0 0 1 1];
basis_no_doublon = [0 1 1 0 0; 0 1 0 1 0; 0 1 0 0 1; 0 0 1 1 0; 0 0 1 0 1; 0 0 0 1 1; NaN NaN NaN NaN NaN; NaN NaN NaN NaN NaN; NaN NaN NaN NaN NaN; NaN NaN NaN NaN NaN];
% Each histogram is sized from the basis it is actually filled with
% (+1 column for patterns that match no basis row).
state_hist = zeros(phase_len_aux, size(basis_no_doublon,1)+1);
state_hist_doublon = zeros(phase_len_aux, size(basis_doublon,1)+1);

% case 1: difference between atom positions = 0 (doublon)
% case 2: difference between atom positions = 1
% case 3: difference between atom positions > 1
% NOTE(review): apart from Nrealizations, the arrays below are only
% allocated here and are not filled in this section.
Npost = zeros(phase_len_aux, 3, 2);
Nrealizations = zeros(phase_len_aux, 2);
bo = zeros(phase_len_aux, rlen, 3, 2);
bo_density = zeros(phase_len_aux, rlen, 3, 2);
corr_diff1_a = zeros(rlen, rlen, phase_len_aux);
corr_diff2_a = zeros(rlen, rlen, phase_len_aux);
corr_diff0_b = zeros(rlen, rlen, phase_len_aux);
corr_diff2_b = zeros(rlen, rlen, phase_len_aux);
corr_diff1_a_density = zeros(rlen, rlen, phase_len_aux);
corr_diff2_a_density = zeros(rlen, rlen, phase_len_aux);
corr_diff0_b_density = zeros(rlen, rlen, phase_len_aux);
corr_diff2_b_density = zeros(rlen, rlen, phase_len_aux);

for dd = 1:numel(doublon_ramp_dur_list)
    doublon_ramp_dur = doublon_ramp_dur_list(dd);

    % Run numbers belonging to this doublon_ramp_dur value.
    start = 1+phase_len_aux*(dd-1); stop = phase_len_aux*dd;

    for jj = 1:(stop-start+1)

        % All shots of run number start+jj-1 (zero-padded to 3 digits).
        % NOTE(review): the run number indexes files directly, which matches
        % the filtered parameter list only if no earlier rows were discarded
        % by the evolution_dur filter - confirm.
        listing = dir(fullfile(folderName, [prefix '*' num2str(start+jj-1,'%03.f') 'atomMatrix.mat']));
        Nshots = size(listing,1);
        Nrealizations(jj, dd) = (col2-col1+1) * Nshots;
        for i = 1:Nshots
            % Load only atomMatrix, and into a struct, so that nothing else
            % stored in the file can overwrite variables of this script.
            shotFile = fullfile(folderName, listing(i).name);
            shot = load(shotFile, 'atomMatrix');
            if ~isfield(shot, 'atomMatrix')
                error('scanPhase:missingAtomMatrix', ...
                    'File "%s" does not contain a variable named atomMatrix.', shotFile);
            end
            atomMatrix = shot.atomMatrix;

            for col = col1:col2
                % Post-select on exactly two atoms in the analysed rows.
                if sum(atomMatrix(row1:row2, col)) == 2
                    distr = atomMatrix(row_start:row_stop, col)';                           
                    if doublon_ramp_dur > 0
                        state = find(ismember(basis_doublon, distr, 'rows'));
                        if isempty(state)
                            state_hist_doublon(jj,end) = state_hist_doublon(jj,end) + 1;
                        else
                            state_hist_doublon(jj,state) = state_hist_doublon(jj,state) + 1; 
                        end
                    else
                        state = find(ismember(basis_no_doublon, distr, 'rows'));
                        if isempty(state)
                            state_hist(jj,end) = state_hist(jj,end) + 1;
                        else
                            state_hist(jj,state) = state_hist(jj,state) + 1; 
                        end
                    end

                end
            end
        end
    end
end


%% Combine histograms
% Merge the two per-run histograms into a single table over the full basis
% (last column: patterns that matched no basis state).

% Start from the plain sum (this is the final value for columns 2, 3, 5 and
% the last column), then overwrite the columns that are taken from a single
% data set with a factor of two.
% NOTE(review): the factor 2 presumably compensates for those states being
% counted in only one of the two data sets - confirm.
state_hist_total = state_hist + state_hist_doublon;
state_hist_total(:,[1 4 6]) = 2*state_hist(:,[1 4 6]);
state_hist_total(:,7:10) = 2*state_hist_doublon(:,7:10);

% compress: average all runs that share the same phase value
phase_values = all_phase_values(1:size(state_hist_total,1));
state_hist_compressed = zeros(phase_len, size(basis,1)+1);
for tt = 1:phase_len
    same_phase = (phase_values == unique_phase_vals(tt));
    state_hist_compressed(tt,:) = mean(state_hist_total(same_phase,:), 1);
end

% Normalise each phase to probabilities; binomial standard error with the
% per-phase total of the averaged counts as sample size.
counts_per_phase = sum(state_hist_compressed, 2);
pn = state_hist_compressed ./ counts_per_phase;
pn_err = sqrt(pn .* (1 - pn) ./ counts_per_phase);


%% Plot (redo)
% Probability of the basis states [1 1 0 0] and [0 0 1 1] versus phase.

%NOTE: tilt is increasing towards the left on camera, but shown as
%increasing to right in paper --> flip left and right before plotting
% (only the legend label is flipped; use num2str(basis(idx, :)) instead to
% label the states in camera orientation).
b_idx_1 = find(ismember(basis, [1 1 0 0], 'rows'));
b_idx_2 = find(ismember(basis, [0 0 1 1], 'rows'));

if plot_figure
    figure()
    tiledlayout(1,1)

    nexttile
    hold on
    % One errorbar series per selected basis state.
    for b_idx = [b_idx_1 b_idx_2]
        state_label = num2str( flip(basis(b_idx, :)) );
        errorbar(unique_phase_vals, pn(:, b_idx), pn_err(:, b_idx), 'o', ...
            'CapSize', 0, 'DisplayName', state_label)
    end
    ylim([0 0.2])
    ylabel('Probability')
    xlabel('\theta (\pi)')

    legend('Location','best')
end
 
